# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import csv
import numpy as np
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion

# Module-level state shared across calls of
# RegLabelSmoothedCrossEntropyCriterion.forward_reg (curriculum bookkeeping).
# Some of these names are not read by the criterion below; presumably they are
# kept for external users of this module -- verify before removing any.
m0 = 0
sentence_kl_loss = {}
sentence_kl_loss_CL = {}
# Largest / smallest per-sample KL recorded at the most recent end of epoch.
max_kl_loss = 0.0
max_kl_loss_epoch = 0.0
min_kl_loss = 0.0
# Per-sample KL from the previous epoch, keyed by a string prefix of the
# sample's source tokens.
sample_kl_loss_dict = {}
sample_kl_loss_dict_pro = {}
augmentation_masking_probability = 0.10
# Average per-sentence training KL of the previous epoch.
last_model_train_competence = 0
# Running KL sum over the current epoch (divided by dataset_len at epoch end).
last_model_train_competence_kl_loss = 0
# Number of sentences seen so far in the current epoch.
dataset_len = 0
# Per-sample KL collected during the current epoch (same keys as
# sample_kl_loss_dict); it becomes sample_kl_loss_dict at epoch end.
sample_kl_loss_dict_new = {}


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Label-smoothed negative log-likelihood.

    Args:
        lprobs: log-probabilities whose last dimension indexes the classes.
        target: gold class indices, either one rank lower than ``lprobs`` or
            with a trailing dimension of size 1.
        epsilon: smoothing weight; ``0`` yields the plain NLL.
        ignore_index: positions whose target equals this value contribute 0.
        reduce: if True, both returned terms are summed to scalars.

    Returns:
        ``(loss, nll_loss)`` with
        ``loss = (1 - epsilon) * nll + epsilon / lprobs.size(-1) * smooth``,
        where ``smooth`` is the negated sum of the log-probabilities.

    Note:
        With ``reduce=False`` the outputs keep a trailing size-1 dimension
        when ``ignore_index`` is given, and drop it otherwise.
    """
    if target.dim() + 1 == lprobs.dim():
        target = target.unsqueeze(-1)
    nll = lprobs.gather(dim=-1, index=target).neg()
    smooth = lprobs.sum(dim=-1, keepdim=True).neg()
    if ignore_index is None:
        nll, smooth = nll.squeeze(-1), smooth.squeeze(-1)
    else:
        ignored = target.eq(ignore_index)
        nll.masked_fill_(ignored, 0.)
        smooth.masked_fill_(ignored, 0.)
    if reduce:
        nll, smooth = nll.sum(), smooth.sum()
    per_class_eps = epsilon / lprobs.size(-1)
    return (1. - epsilon) * nll + per_class_eps * smooth, nll


def vanilla_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Plain negative log-likelihood of ``target`` under ``lprobs``.

    ``epsilon`` is accepted only for signature parity with
    ``label_smoothed_nll_loss`` and is ignored.

    Args:
        lprobs: log-probabilities whose last dimension indexes the classes.
        target: gold class indices, one rank lower than ``lprobs`` or with a
            trailing dimension of size 1.
        ignore_index: positions whose target equals this value contribute 0.
        reduce: if True, return the scalar sum.

    Returns:
        The NLL tensor. Un-reduced, it keeps a trailing size-1 dimension when
        ``ignore_index`` is given and drops it otherwise.
    """
    if lprobs.dim() - target.dim() == 1:
        target = target.unsqueeze(-1)
    picked = lprobs.gather(dim=-1, index=target)
    loss = -picked
    if ignore_index is not None:
        loss.masked_fill_(target.eq(ignore_index), 0.)
    else:
        loss = loss.squeeze(-1)
    return loss.sum() if reduce else loss


@register_criterion('reg_label_smoothed_cross_entropy')
class RegLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
    """Label-smoothed cross entropy with a two-pass consistency regularizer.

    :meth:`forward` computes the plain label-smoothed cross entropy.
    :meth:`forward_reg` runs the batch concatenated with a copy of itself,
    adds ``reg_alpha`` times the symmetric KL between the two passes,
    optionally reweights every sentence with a competence-based curriculum
    weight, and calls ``optimizer.backward`` itself.
    """

    def __init__(self, task, sentence_avg, label_smoothing):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        # fmt: on

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(loss, nll_loss)`` of ``net_output`` against the sample's targets.

        Log-probabilities and targets are flattened to one row per token;
        positions equal to ``self.padding_idx`` contribute zero.
        """
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)

        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
        )
        return loss, nll_loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        ``loss`` is normalized by ``sample_size`` and ``nll_loss`` by
        ``ntokens``, both in base 2; ``ppl`` is derived from ``nll_loss``.
        """
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'nll_loss': nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def _paired_kl(log_probs, probs, use_mixture):
        """Split a doubled batch in half along dim 0 and return element-wise KL terms.

        ``log_probs`` and ``probs`` describe the same distributions; the first
        half is ``p`` and the second half is ``q``. With ``use_mixture=False``,
        ``p`` is compared against ``q``'s probabilities and vice versa. With
        ``use_mixture=True``, both halves are compared against the mixture
        ``m = (p + q) / 2``. Each term is ``kl_div(input, target,
        reduction='none')``.
        """
        p, q = torch.split(log_probs, log_probs.size(0) // 2, dim=0)
        p_prob, q_prob = torch.split(probs, probs.size(0) // 2, dim=0)
        if use_mixture:
            mixture = (p_prob + q_prob) / 2
            p_target, q_target = mixture, mixture
        else:
            p_target, q_target = q_prob, p_prob
        p_loss = torch.nn.functional.kl_div(p, p_target, reduction='none')
        q_loss = torch.nn.functional.kl_div(q, q_target, reduction='none')
        return p_loss, q_loss

    def _token_kl_terms(self, model, net_output, pad_mask, use_mixture):
        """Element-wise token-level KL terms of the model output; padding (``pad_mask``) is zeroed."""
        p_loss, q_loss = self._paired_kl(
            model.get_normalized_probs(net_output, log_probs=True),
            model.get_normalized_probs(net_output, log_probs=False),
            use_mixture,
        )
        if pad_mask is not None:
            p_loss.masked_fill_(pad_mask, 0.)
            q_loss.masked_fill_(pad_mask, 0.)
        return p_loss, q_loss

    def _sentence_kl_terms(self, net_output, pad_mask, use_mixture):
        """Element-wise KL terms between softmaxes of :meth:`sentence_embedding` for the two halves."""
        sentence_logit = self.sentence_embedding(net_output[0], pad_mask)
        return self._paired_kl(
            torch.log_softmax(sentence_logit, dim=-1),
            torch.softmax(sentence_logit, dim=-1),
            use_mixture,
        )

    @staticmethod
    def _combine_kl_terms(p_loss, q_loss, sentence_p_loss, sentence_q_loss, CL_type):
        """Fold token- and sentence-level KL terms into one scalar.

        CL_type 1: mean of the two token sums plus mean of the two sentence sums.
        CL_type 2: per-sentence token KL (summed over the two trailing dims)
            times the mean of ``exp(sentence KL)`` of the two halves, summed.
        CL_type 3: each half's per-sentence token KL weighted by its own
            ``exp(sentence KL)``, averaged over halves, summed.
        Otherwise: mean of the two token sums (sentence terms unused; may be None).
        """
        if CL_type == 1:
            token_loss = (p_loss.sum() + q_loss.sum()) / 2
            sentence_loss = (sentence_p_loss.sum() + sentence_q_loss.sum()) / 2
            return token_loss + sentence_loss
        if CL_type in (2, 3):
            token_p = p_loss.sum(dim=-1).sum(dim=-1)
            token_q = q_loss.sum(dim=-1).sum(dim=-1)
            weight_p = torch.exp(sentence_p_loss.sum(dim=-1))
            weight_q = torch.exp(sentence_q_loss.sum(dim=-1))
            if CL_type == 2:
                return ((token_p + token_q) / 2 * ((weight_p + weight_q) / 2)).sum()
            return ((weight_p * token_p + weight_q * token_q) / 2).sum()
        return (p_loss.sum() + q_loss.sum()) / 2

    def compute_kl_loss(self, model, net_output, pad_mask=None, reduce=False):
        """Symmetric KL between the two halves of a doubled batch, as a scalar.

        ``net_output`` must come from a batch built by stacking two copies
        along dim 0 (see :meth:`forward_reg`). With the hard-coded
        ``CL_type = 0`` this is ``(sum KL(q||p) + sum KL(p||q)) / 2`` over
        non-padding tokens. ``reduce`` is unused.
        """
        CL_type = 0
        p_loss, q_loss = self._token_kl_terms(model, net_output, pad_mask, use_mixture=False)
        sentence_p_loss = sentence_q_loss = None
        if CL_type in (1, 2, 3):
            # Sentence-level terms only feed CL_type 1-3; skip the work otherwise.
            sentence_p_loss, sentence_q_loss = self._sentence_kl_terms(
                net_output, pad_mask, use_mixture=False)
        return self._combine_kl_terms(p_loss, q_loss, sentence_p_loss, sentence_q_loss, CL_type)

    def _per_sentence_kl(self, model, net_output, pad_mask):
        """Per-sentence ``(KL(q||p) + KL(p||q)) / 2`` of a doubled batch, shape ``(batch,)``.

        Every term is summed over all dimensions except dim 0, so the result's
        sum matches :meth:`compute_kl_loss` (up to rounding).
        """
        p_loss, q_loss = self._token_kl_terms(model, net_output, pad_mask, use_mixture=False)
        return (p_loss.flatten(1).sum(dim=-1) + q_loss.flatten(1).sum(dim=-1)) / 2

    def _replace_token(self, inputs, masking_indices, mask_index, vocab_size):
        """Write ``mask_index`` into ``inputs`` in place wherever ``masking_indices`` is True."""
        inputs[masking_indices] = mask_index
        return inputs

    def CipherDAug(self, inputs, vocab_dict, augmentation_masking_probability = 0.1):
        """Return a copy of ``inputs`` with random tokens replaced by unk.

        Every token that is not bos/eos/pad/unk is independently replaced by
        ``vocab_dict.unk()`` with probability ``augmentation_masking_probability``.
        """
        vocab_size = len(vocab_dict)
        bos_index, eos_index = vocab_dict.bos(), vocab_dict.eos()
        pad_index, unk_index = vocab_dict.pad(), vocab_dict.unk()
        available_token_indices = (inputs != bos_index) & (inputs != eos_index) & (inputs != pad_index) & (
                inputs != unk_index)
        random_masking_indices = torch.bernoulli(
            torch.full(inputs.shape, augmentation_masking_probability, device=inputs.device)).bool()
        masked_inputs = inputs.clone()
        masking_indices = random_masking_indices & available_token_indices
        masked_inputs = self._replace_token(masked_inputs, masking_indices, unk_index, vocab_size)
        return masked_inputs

    def sentence_embedding(self, net_output, pad_mask):
        """Masked mean over dim 1 of ``net_output`` for a doubled batch.

        ``pad_mask`` (True = padding) covers one half of the batch and is
        tiled to both halves. With ``pad_mask=None`` a plain mean over dim 1
        is returned.
        NOTE(review): assumes ``pad_mask`` broadcasts against ``net_output``
        after tiling, e.g. (batch, tgt_len, 1) vs (2*batch, tgt_len, dim).
        """
        if pad_mask is None:
            return net_output.mean(dim=1)
        keep = (~pad_mask).float()
        keep = torch.cat([keep, keep.clone()], 0)
        return (net_output * keep).sum(dim=1) / keep.sum(dim=1)

    def compute_regularization_loss(self, model, net_output, pad_mask=None, reduce=False):
        """Divergence of both halves of a doubled batch from their mixture, as a scalar.

        At the token level and at the :meth:`sentence_embedding` level, each
        half is compared (via ``kl_div``) against ``m = (p + q) / 2``. The
        terms are combined by :meth:`_combine_kl_terms` with the hard-coded
        ``CL_type = 2``. ``reduce`` is unused.
        """
        CL_type = 2
        p_loss, q_loss = self._token_kl_terms(model, net_output, pad_mask, use_mixture=True)
        sentence_p_loss = sentence_q_loss = None
        if CL_type in (1, 2, 3):
            sentence_p_loss, sentence_q_loss = self._sentence_kl_terms(
                net_output, pad_mask, use_mixture=True)
        return self._combine_kl_terms(p_loss, q_loss, sentence_p_loss, sentence_q_loss, CL_type)

    @staticmethod
    def _curriculum_weight(sample_kl, model_competence, min_kl, max_kl):
        """Curriculum weight ``exp(1 - |competence - difficulty|) - 1`` for one sentence.

        ``difficulty = (sample_kl - min_kl) / max_kl``. Returns 1.0 (no
        reweighting) when the sentence has no recorded KL or ``max_kl`` is
        not positive, which would otherwise divide by zero.
        """
        # NOTE(review): the previous expression had unbalanced parentheses;
        # it is read here as exp(1 - |d|) - 1. Confirm the intended form.
        if sample_kl is None or max_kl <= 0:
            return 1.0
        difficulty = (float(sample_kl) - min_kl) / max_kl
        return math.exp(1 - abs(model_competence - difficulty)) - 1

    def forward_reg(self, model, sample, optimizer, reg_alpha, ignore_grad, reduce=True):  # CCL
        """Consistency-regularized training step with a competence-based curriculum.

        The batch is concatenated with a copy of itself (the source side is
        token-masked by :meth:`CipherDAug` when ``data_aug`` is on) and run
        through ``model`` once. For every sentence ``i`` the loss is the
        label-smoothed NLL of both copies plus ``reg_alpha * kl_i``, where
        ``kl_i`` is the symmetric token-level KL between the two copies.

        When ``CL_case_3`` is on and ``sample['epoch_itr.epoch'] > 1``, each
        sentence loss is multiplied by :meth:`_curriculum_weight`. The weight
        uses the KL recorded for that sentence in the previous epoch and a
        competence derived from ``sample['kl_loss_res_globa']``. When
        ``sample['end_of_epoch']`` is set, the module-level statistics are
        rolled over to the next epoch.

        If ``ignore_grad`` is set, the loss is zeroed before
        ``optimizer.backward``.

        Returns:
            ``(loss, sample_size, logging_output)``; ``sample_size`` is
            ``sample['ntokens']``.
        """
        global last_model_train_competence, last_model_train_competence_kl_loss
        global dataset_len, min_kl_loss, max_kl_loss
        global sample_kl_loss_dict, sample_kl_loss_dict_new

        data_aug = False
        CL_case_3 = True
        lambda_r = 10

        sample_input = sample['net_input']
        src_tokens = sample_input['src_tokens']
        src_lengths = sample_input['src_lengths']
        prev_output_tokens = sample_input['prev_output_tokens']
        if data_aug:
            # NOTE(review): the source dictionary is taken from the task; the
            # previous code read sample['scr_dict'], which samples do not carry.
            second_src_tokens = self.CipherDAug(src_tokens, self.task.source_dictionary)
        else:
            second_src_tokens = src_tokens.clone()
        sample_concat_input = {
            'src_tokens': torch.cat([src_tokens, second_src_tokens], 0),
            'src_lengths': torch.cat([src_lengths, src_lengths.clone()], 0),
            # Only the source side is ever augmented; both copies share the decoder input.
            'prev_output_tokens': torch.cat([prev_output_tokens, prev_output_tokens.clone()], 0),
        }
        net_output = model(**sample_concat_input)

        target = model.get_targets(sample, net_output)
        batch_size = target.size(0)
        pad_mask = target.unsqueeze(-1).eq(self.padding_idx)
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        # One row per (copy, sentence): (2 * batch, tgt_len, vocab).
        lprobs = lprobs.view(2 * batch_size, -1, lprobs.size(-1))

        # Label-smoothed NLL per sentence, summed over both copies: (batch,).
        token_loss, token_nll = label_smoothed_nll_loss(
            lprobs, torch.cat([target, target.clone()], dim=0), self.eps,
            ignore_index=self.padding_idx, reduce=False,
        )
        copy_loss = token_loss.reshape(2 * batch_size, -1).sum(dim=1)
        copy_nll = token_nll.reshape(2 * batch_size, -1).sum(dim=1)
        sentence_loss = copy_loss[:batch_size] + copy_loss[batch_size:]
        sentence_nll = copy_nll[:batch_size] + copy_nll[batch_size:]

        sentence_kl = self._per_sentence_kl(model, net_output, pad_mask)
        sentence_loss = sentence_loss + reg_alpha * sentence_kl

        # Epoch statistics are kept as Python floats so that no autograd
        # graph outlives this step.
        sentence_kl_values = sentence_kl.detach().tolist()
        last_model_train_competence_kl_loss += sum(sentence_kl_values)
        dataset_len += batch_size
        sample_keys = [str(src_tokens[i])[:50] for i in range(batch_size)]
        for key, value in zip(sample_keys, sentence_kl_values):
            sample_kl_loss_dict_new[key] = value

        if CL_case_3 and sample['epoch_itr.epoch'] > 1:
            model_competence_valid = float(sample['kl_loss_res_globa'])
            model_competence = min(
                1, abs(model_competence_valid - last_model_train_competence) / lambda_r)
            betas = [
                self._curriculum_weight(
                    sample_kl_loss_dict.get(key), model_competence, min_kl_loss, max_kl_loss)
                for key in sample_keys
            ]
            sentence_loss = sentence_loss * sentence_loss.new_tensor(betas)

        loss_total = sentence_loss.sum()
        nll_loss = sentence_nll.sum()

        if CL_case_3 and sample['end_of_epoch']:
            # Roll this epoch's statistics over once per batch, after the
            # weights above were computed from the previous epoch's values.
            if dataset_len:
                last_model_train_competence = last_model_train_competence_kl_loss / dataset_len
            last_model_train_competence_kl_loss = 0
            dataset_len = 0
            sample_kl_loss_dict = sample_kl_loss_dict_new
            sample_kl_loss_dict_new = {}
            if sample_kl_loss_dict:
                kl_values = [float(v) for v in sample_kl_loss_dict.values()]
                min_kl_loss = min(kl_values)
                max_kl_loss = max(kl_values)

        if ignore_grad:
            loss_total = loss_total * 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss_total)
        ntokens = sample['ntokens']
        nsentences = sample['target'].size(0)
        sample_size = sample['ntokens']
        logging_output = {
            'loss': utils.item(loss_total.data) if reduce else loss_total.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss_total, sample_size, logging_output
